#!/usr/bin/env python3
import os
import numpy as np
import pandas as pd
import sys
sys.path.append('..')
from datasets.argoverse_pickle_loader import read_pkl_data
from datasets.helper import get_lane_direction
from tensorpack import dataflow
import time
import gc
import pickle
import helper
import time
import glob
from argoverse.map_representation.map_api import ArgoverseMap
from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader

dataset_path = '~/particle/argoverse/argoverse_forecasting/'
lane_path = '~/particle/TrafficFluids/datasets/'

val_path = os.path.join(dataset_path, 'val', 'clean_data')
train_path = os.path.join(dataset_path, 'train', 'clean_data')
test_path = os.path.join(dataset_path, 'test_obs', 'data')

class ArgoverseTest(dataflow.RNGDataFlow):
    """
    Tensorpack dataflow over a directory of Argoverse forecasting CSV sequences.

    Each yielded sample is a dict for one scene containing per-track history
    (``pos_2s``/``vel_2s``), the state at the 20th timestamp
    (``pos0``/``vel0``/``track_id0``), the agent's track id, a car mask and
    placeholder lane arrays. Per-track arrays are padded to ``max_car_num``.
    """

    def __init__(self, file_path: str, shuffle: bool = True, random_rotation: bool = False,
                 max_car_num: int = 50, freq: int = 10, use_interpolate: bool = False,
                 lane_path: str = "~/particle/TrafficFluids/datasets",
                 use_lane: bool = False, use_mask: bool = True):
        """
        Args:
            file_path: directory of forecasting CSV files; a leading ``~`` is
                expanded to the user's home directory.
            shuffle: shuffle the scene order on every pass.
            random_rotation: stored on the instance; not used by this class.
            max_car_num: scenes with more tracks than this are skipped, and
                every per-track array is padded to this length.
            freq: multiplier applied to per-step position differences to get
                velocities (presumably the sampling rate in Hz -- confirm).
            use_interpolate: stored on the instance; not used by this class.
            lane_path: accepted for interface compatibility; unused.
            use_lane: accepted for interface compatibility; unused.
            use_mask: stored on the instance; not used by this class.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        file_path = os.path.expanduser(file_path)
        if not os.path.exists(file_path):
            # FileNotFoundError subclasses Exception, so existing handlers still match.
            raise FileNotFoundError("Path does not exist.")

        self.afl = ArgoverseForecastingLoader(file_path)
        self.shuffle = shuffle
        self.random_rotation = random_rotation
        self.max_car_num = max_car_num
        self.freq = freq
        self.use_interpolate = use_interpolate
        self.am = ArgoverseMap()
        self.use_mask = use_mask
        self.file_path = file_path

    def __iter__(self):
        """Yield one feature dict per scene that has at most ``max_car_num`` tracks."""
        scene_idxs = np.arange(len(self.afl))
        if self.shuffle:
            self.rng.shuffle(scene_idxs)

        for scene in scene_idxs:
            # Look the sequence up once and reuse it for every attribute below.
            seq = self.afl[scene]
            if seq.num_tracks > self.max_car_num:
                continue

            data, city = seq.seq_df, seq.city

            # Placeholder lane features: a single zero row each.
            lane = np.array([[0., 0., 0.]], dtype=np.float32)
            lane_drct = np.array([[0., 0., 0.]], dtype=np.float32)

            tstmps = data.TIMESTAMP.unique()
            tstmps.sort()

            # Keep only tracks that have a row at each of the first 20 timestamps.
            data = self._filter_imcomplete_data(data, tstmps, 20)
            # Adds vel_x/vel_y and drops the first row of every track.
            data = self._calc_vel(data, self.freq)

            agent = data[data['OBJECT_TYPE'] == 'AGENT']['TRACK_ID'].values[0]

            # The first N rows mark real tracks (N = tracks left after filtering);
            # the remaining rows correspond to padding.
            car_mask = np.zeros((self.max_car_num, 1), dtype=np.float32)
            car_mask[:len(data.TRACK_ID.unique())] = 1.0

            feat_dict = {'city': city,
                         'lane': lane,
                         'lane_norm': lane_drct,
                         'scene_idx': scene,
                         'agent_id': agent,
                         'car_mask': car_mask}

            # History window: rows among the first 19 timestamps. Each track's
            # first remaining step is duplicated at the front (np.insert) to
            # make up for the row dropped by _calc_vel.
            history = data[data['TIMESTAMP'].isin(tstmps[:19])]

            pos_enc = np.concatenate([subdf[['X', 'Y']].values[np.newaxis, :]
                                      for _, subdf in history.groupby('TRACK_ID')], axis=0)
            pos_enc = np.insert(pos_enc, 0, axis=1, values=pos_enc[:, 0])
            pos_enc = self._expand_dim(pos_enc)
            feat_dict['pos_2s'] = self._expand_particle(pos_enc, self.max_car_num, 0)

            vel_enc = np.concatenate([subdf[['vel_x', 'vel_y']].values[np.newaxis, :]
                                      for _, subdf in history.groupby('TRACK_ID')], axis=0)
            vel_enc = np.insert(vel_enc, 0, axis=1, values=vel_enc[:, 0])
            vel_enc = self._expand_dim(vel_enc)
            feat_dict['vel_2s'] = self._expand_particle(vel_enc, self.max_car_num, 0)

            # State at the 20th timestamp (index 19).
            current = data[data['TIMESTAMP'] == tstmps[19]]
            pos = self._expand_dim(current[['X', 'Y']].values)
            feat_dict['pos0'] = self._expand_particle(pos, self.max_car_num, 0)
            vel = self._expand_dim(current[['vel_x', 'vel_y']].values)
            feat_dict['vel0'] = self._expand_particle(vel, self.max_car_num, 0)
            track_id = current['TRACK_ID'].values
            feat_dict['track_id0'] = self._expand_particle(track_id, self.max_car_num, 0, 'str')
            feat_dict['frame_id0'] = 0

            yield feat_dict

    def __len__(self):
        """
        Return the number of entries in ``file_path``.

        This can exceed the number of yielded samples, because ``__iter__``
        skips scenes with too many tracks.
        """
        return len(glob.glob(os.path.join(self.file_path, '*')))

    @classmethod
    def __expand_df_generator(cls, df, city_name):
        """
        Yield one frame per (timestamp, track) pair.

        The existing row is yielded when present. Otherwise a placeholder row
        with NaN X/Y is yielded, carrying the track's OBJECT_TYPE.
        """
        ids = df.TRACK_ID.unique()
        for tstmp, sub_df in df.groupby('TIMESTAMP'):
            for idx in ids:
                if idx not in sub_df.TRACK_ID.values:
                    yield pd.DataFrame(dict(TIMESTAMP=[tstmp], TRACK_ID=[idx], X=[np.nan], Y=[np.nan],
                                            CITY_NAME=[city_name],
                                            OBJECT_TYPE=[df[df['TRACK_ID'] == idx]['OBJECT_TYPE'].iloc[0]]))
                else:
                    yield df[(df['TIMESTAMP'] == tstmp) & (df['TRACK_ID'] == idx)]

    @classmethod
    def _expand_df(cls, df, city_name):
        """Return ``df`` with NaN-position rows added for every missing (timestamp, track) pair."""
        return pd.concat(cls.__expand_df_generator(df, city_name), axis=0)

    @classmethod
    def __calc_vel_generator(cls, df, freq=10):
        """Per track, add vel_x/vel_y = diff(X, Y) * freq and drop the first (NaN) row."""
        for _, subdf in df.groupby('TRACK_ID'):
            sub_df = subdf.copy()
            sub_df[['vel_x', 'vel_y']] = sub_df[['X', 'Y']].diff() * freq
            yield sub_df.iloc[1:, :]

    @classmethod
    def _calc_vel(cls, df, freq=10):
        """Return ``df`` with velocity columns; rows come out grouped by TRACK_ID."""
        return pd.concat(cls.__calc_vel_generator(df, freq=freq), axis=0)

    @classmethod
    def _expand_dim(cls, ndarr, dtype=np.float32):
        """Insert a zero at index 2 of the last axis and cast, e.g. (..., 2) -> (..., 3)."""
        return np.insert(ndarr, 2, values=0, axis=-1).astype(dtype)

    @classmethod
    def _linear_interpolate_generator(cls, data, col=None):
        """Per track, linearly interpolate NaNs in ``col`` (default ['X', 'Y']) in both directions."""
        if col is None:
            col = ['X', 'Y']
        for _, df in data.groupby('TRACK_ID'):
            sub_df = df.copy()
            sub_df[col] = sub_df[col].interpolate(limit_direction='both')
            yield sub_df

    @classmethod
    def _linear_interpolate(cls, data, col=None):
        """Return ``data`` with ``col`` (default ['X', 'Y']) interpolated per track."""
        return pd.concat(cls._linear_interpolate_generator(data, col), axis=0)

    @classmethod
    def _filter_imcomplete_data(cls, data, tstmps, window=20):
        """Keep only tracks with exactly ``window`` rows among the first ``window`` timestamps."""
        complete_id = [idx
                       for idx, subdf in data[data['TIMESTAMP'].isin(tstmps[:window])].groupby('TRACK_ID')
                       if len(subdf) == window]
        return data[data['TRACK_ID'].isin(complete_id)]

    @classmethod
    def _expand_particle(cls, arr, max_num, axis, value_type='int'):
        """
        Pad ``arr`` along ``axis`` up to ``max_num`` entries.

        Numeric padding is float64 zeros. With ``value_type == 'str'`` the
        padding is the unique strings 'dummy0', 'dummy1', ...
        """
        dummy_shape = list(arr.shape)
        dummy_shape[axis] = max_num - arr.shape[axis]
        dummy = np.zeros(dummy_shape)
        if value_type == 'str':
            # np.product was removed in NumPy 2.0; np.prod is the supported name.
            dummy = np.array(['dummy' + str(i) for i in range(np.prod(dummy_shape))]).reshape(dummy_shape)
        return np.concatenate([arr, dummy], axis=axis)


def read_data(file_path=None,
              batch_size=1,
              random_rotation=False,
              repeat=False,
              num_workers=1,
              **kwargs):
    """
    Build a batched tensorpack dataflow over an Argoverse CSV directory.

    Args:
        file_path: directory passed to ``ArgoverseTest``.
        batch_size: samples per batch (``BatchData`` with ``use_list=True``).
        random_rotation: forwarded to ``ArgoverseTest``.
        repeat: accepted for call compatibility; not used here.
        num_workers: when greater than 1, samples are produced by that many
            processes via ``MultiProcessRunnerZMQ``.
        **kwargs: further ``ArgoverseTest`` constructor arguments. ``shuffle``
            must not be among them, because it is always set to False here.

    Returns:
        The batched dataflow, with ``reset_state()`` already called.
    """
    source = ArgoverseTest(
        file_path=file_path,
        random_rotation=random_rotation,
        shuffle=False,
        **kwargs
    )

    # Single-process unless more than one worker was requested.
    if num_workers <= 1:
        producer = source
    else:
        producer = dataflow.MultiProcessRunnerZMQ(source, num_proc=num_workers)

    batched = dataflow.BatchData(producer, batch_size=batch_size, use_list=True)
    batched.reset_state()
    return batched


def read_data_test(file_path, **kwargs):
    """Build the test-split dataflow with ``read_data``, forcing one worker process.

    Any other keyword arguments are passed through to ``read_data`` unchanged.
    """
    # num_workers is pinned to 1; passing it in kwargs as well is an error.
    return read_data(file_path, num_workers=1, **kwargs)

    
    
class process_utils(object):
    """Lane lookup tables plus small array helpers used when exporting lane features."""

    def __init__(self, lane_path):
        """
        Load ``lanes.pkl`` and ``lane_drct.pkl`` from ``lane_path``.

        A leading ``~`` in ``lane_path`` is expanded to the user's home
        directory. Both pickles are indexed by city below; their exact layout
        is assumed, not checked -- confirm against the files.
        NOTE(review): pickle.load runs arbitrary code, so only point this at
        trusted files.
        """
        lane_path = os.path.expanduser(lane_path)
        with open(os.path.join(lane_path, 'lanes.pkl'), 'rb') as f:
            self.lanes = pickle.load(f)
        with open(os.path.join(lane_path, 'lane_drct.pkl'), 'rb') as f:
            self.lane_drct = pickle.load(f)

    @classmethod
    def expand_dim(cls, ndarr, dtype=np.float32):
        """Insert a zero at index 2 of the last axis and cast, e.g. (..., 2) -> (..., 3)."""
        return np.insert(ndarr, 2, values=0, axis=-1).astype(dtype)

    def __look_up_lane_drct(self, lane, city):
        """Return a (1, k) row of the unique direction values for lane points equal to ``lane``."""
        # NOTE(review): np.unique flattens and sorts, so the components of a
        # direction vector come back sorted rather than in (x, y) order.
        # Confirm this is intended.
        d = np.unique(self.lane_drct[city][np.equal(self.lanes[city], lane).all(axis=1)])
        return d[np.newaxis, :]

    def look_up_lane_drct(self, lane, city):
        """Stack the per-point direction lookups for every point in ``lane``."""
        return np.concatenate([self.__look_up_lane_drct(l, city) for l in lane], axis=0)

    @classmethod
    def expand_particle(cls, arr, max_num, axis, value_type='int'):
        """
        Pad ``arr`` along ``axis`` up to ``max_num`` entries.

        Numeric padding is float64 zeros. With ``value_type == 'str'`` the
        padding is the unique strings 'dummy0', 'dummy1', ...
        """
        dummy_shape = list(arr.shape)
        dummy_shape[axis] = max_num - arr.shape[axis]
        dummy = np.zeros(dummy_shape)
        if value_type == 'str':
            # np.product was removed in NumPy 2.0; np.prod is the supported name.
            dummy = np.array(['dummy' + str(i) for i in range(np.prod(dummy_shape))]).reshape(dummy_shape)
        return np.concatenate([arr, dummy], axis=axis)
    

def get_max_min(datas):
    """
    Compute the x/y bounding box of the valid cars' positions in a batch.

    Only the first batch element is inspected. Cars whose ``car_mask`` entry
    is zero are ignored. Both ``pos0`` and ``pos_2s`` contribute, and the box
    covers both of them.

    Args:
        datas: batched feature dict with list-valued entries ``car_mask``,
            ``pos0`` and ``pos_2s`` (x in channel 0, y in channel 1 of the
            last axis).

    Returns:
        ``(min_x, max_x, min_y, max_y)``. Each is an array of shape
        ``(batch,)``, i.e. shape ``(1,)`` for batch size 1.
    """
    slicer = datas['car_mask'][0].astype(bool).flatten()
    pos_keys = ['pos0', 'pos_2s']

    def _reduce(reducer, coord):
        # Reduce coordinate `coord` within each key, then across keys, using the
        # same reducer for both steps.
        per_key = []
        for pk in pos_keys:
            stacked = np.stack(datas[pk])
            vals = stacked[0, slicer, ..., coord].reshape(stacked.shape[0], -1)
            per_key.append(reducer(vals, axis=-1)[..., np.newaxis])
        return reducer(np.concatenate(per_key, axis=-1), axis=-1)

    # Minima must be combined with min. The original code used max here, which
    # gave the larger of the two per-key minima and so shrank the box.
    min_x = _reduce(np.min, 0)
    max_x = _reduce(np.max, 0)
    min_y = _reduce(np.min, 1)
    max_y = _reduce(np.max, 1)
    return min_x, max_x, min_y, max_y


def process_func(putil, datas, am):
    """
    Attach lane centerline points that overlap the trajectories' extent.

    A lane is kept when its centerline's bounding box overlaps the bounding
    box from ``get_max_min``. For each kept lane, the centerline points after
    the first one are stored together with the matching per-segment
    differences.

    Args:
        putil: object providing ``expand_dim`` (e.g. ``process_utils``).
        datas: batched feature dict; ``datas['city'][0]`` selects the city.
        am: map object exposing ``city_lane_centerlines_dict``.

    Returns:
        ``datas`` itself. When at least one lane matches, ``datas['lane']``
        and ``datas['lane_norm']`` are replaced in place.
    """
    city = datas['city'][0]
    x_min, x_max, y_min, y_max = get_max_min(datas)
    city_lanes = am.city_lane_centerlines_dict[city]

    points = []
    directions = []
    for props in city_lanes.values():
        cl = props.centerline
        overlaps = (np.min(cl[:, 0]) < x_max
                    and np.min(cl[:, 1]) < y_max
                    and np.max(cl[:, 0]) > x_min
                    and np.max(cl[:, 1]) > y_min)
        if not overlaps:
            continue
        # Dropping the first point pairs each kept point with the segment ending at it.
        points.append(cl[1:])
        directions.append(np.diff(cl, axis=0))

    if not points:
        return datas

    lane = putil.expand_dim(np.concatenate(points, axis=0))
    lane_norm = putil.expand_dim(np.concatenate(directions, axis=0))[..., :3]
    datas['lane'] = [lane]
    datas['lane_norm'] = [lane_norm]
    return datas
    
    
if __name__ == '__main__':
    am = ArgoverseMap()
    putil = process_utils(lane_path)

    # val_dataset / dataset are only consumed by the disabled export below.
    val_dataset = read_pkl_data(val_path, batch_size=1, shuffle=False, repeat=False)
    dataset = read_pkl_data(train_path, batch_size=1, repeat=False, shuffle=True)
    test_dataset = read_data_test(test_path, batch_size=1, repeat=False)

    # Create the output directory up front so the first open(..., 'wb') cannot
    # fail. Expand "~" so a literal "~" directory is never created in the cwd.
    test_out_dir = os.path.expanduser(os.path.join(dataset_path, 'test_obs/lane_data'))
    os.makedirs(test_out_dir, exist_ok=True)

    test_num = len(test_dataset)
    batch_start = time.time()
    for i, data in enumerate(test_dataset):
        # Report progress and elapsed time every 1000 samples.
        if i % 1000 == 0:
            batch_end = time.time()
            print("SAVED ============= {} / {} ....... {}".format(i, test_num, batch_end - batch_start))
            batch_start = time.time()

        datas = process_func(putil, data, am)
        if datas is None:  # defensive: skip samples with no result
            continue
        with open(os.path.join(test_out_dir, str(datas['scene_idx'][0]) + '.pkl'), 'wb') as f:
            pickle.dump(datas, f)

    # Disabled train/val export, kept as a string literal so it does not run.
    # Re-enable by removing the quotes; the output directories must exist.
    """
    existing_files = glob.glob(os.path.join(dataset_path, 'train/lane_data', '*.pkl'))
    train_num = len(dataset)
    batch_start = time.time()
    for i, data in enumerate(dataset):
        if i % 1000 == 0:
            batch_end = time.time()
            print("SAVED ============= {} / {} ....... {}".format(i, train_num, batch_end - batch_start))
            batch_start = time.time()
        
        file_name = os.path.join(dataset_path, 'train/lane_data', str(data['scene_idx'][0])+'.pkl')
        if file_name in existing_files:
            continue
        datas = process_func(putil, data, am)
        if datas is None:
            continue
        with open(os.path.join(dataset_path, 'train/lane_data', str(datas['scene_idx'][0])+'.pkl'), 'wb') as f:
            pickle.dump(datas, f)
    
    val_num = len(val_dataset)
    batch_start = time.time()
    for i, data in enumerate(val_dataset):
        if i % 1000 == 0:
            batch_end = time.time()
            print("SAVED ============= {} / {} ....... {}".format(i, val_num, batch_end - batch_start))
            batch_start = time.time()
        
        datas = process_func(putil, data, am)
        if datas is None:
            continue
        with open(os.path.join(dataset_path, 'val/lane_data', str(datas['scene_idx'][0])+'.pkl'), 'wb') as f:
            pickle.dump(datas, f)
    """

    